{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.datasets import mnist \n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 28, 28)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_shift(data_s, direc='u'):\n",
    "    ret = np.zeros((len(data_s), 784))\n",
    "    size = len(data_s)\n",
    "    if direc == 'u':\n",
    "        for i in range(size):\n",
    "            trans_data = np.append(data_s[i][1:,:], data_s[i][0:1,:],axis=0)\n",
    "            ret[i] = trans_data.reshape(1, -1)\n",
    "    elif direc == 'd':\n",
    "        for i in range(size):\n",
    "            trans_data = np.append(data_s[i][-1:,:], data_s[i][:-1,:],axis=0)\n",
    "            ret[i] = trans_data.reshape(1, -1)\n",
    "    elif direc == 'l':\n",
    "        for i in range(size):\n",
    "            trans_data = np.append(data_s[i][:,1:], data_s[i][:,0:1],axis=1)\n",
    "            ret[i] = trans_data.reshape(1, -1)\n",
    "    elif direc == 'r':\n",
    "        for i in range(size):\n",
    "            trans_data = np.append(data_s[i][:,-1:], data_s[i][:,:-1],axis=1)\n",
    "            ret[i] = trans_data.reshape(1, -1)\n",
    "    return ret\n",
    "\n",
    "X_trainu = data_shift(X_train, 'u')\n",
    "X_traind = data_shift(X_train, 'd')\n",
    "X_trainl = data_shift(X_train, 'l')\n",
    "X_trainr = data_shift(X_train, 'r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 784)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_trainu.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(300000, 784)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train_A = np.concatenate((X_train.reshape(60000,784), X_trainu, X_traind, X_trainl, X_trainr), axis = 0)\n",
    "X_train_A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000,)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(300000,)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train_A = np.concatenate((y_train, y_train, y_train, y_train, y_train), axis=0)\n",
    "y_train_A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier as KNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ps= [0.96340257 0.95450716 0.98216056 0.96442688 0.9762151  0.96528555\n",
      " 0.98130841 0.96108949 0.98809524 0.95626243] 0.9692753386570571\n"
     ]
    }
   ],
   "source": [
    "# 数据未增广\n",
    "knn_model = KNN()\n",
    "knn_model.fit(X_train.reshape(-1,784), y_train)\n",
    "y_pred = knn_model.predict(X_test.reshape(-1,784))\n",
    "\n",
    "# 评估模型表现\n",
    "from sklearn.metrics import precision_score\n",
    "ps = precision_score(y_test, y_pred, average=None)\n",
    "print('\\nps=', ps, np.average(ps))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用增广的数据进行训练\n",
    "knn_model_A = KNN()\n",
    "knn_model_A.fit(X_train_A, y_train_A)\n",
    "y_pred_A = knn_model_A.predict(X_test.reshape(-1,784))\n",
    "\n",
    "# 评估模型表现\n",
    "psA = precision_score(y_test, y_pred, average=None) \n",
    "print('\\npsA=', psA, np.average(psA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
